{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Physionet 2017 | ECG Rhythm Classification\n",
    "## 4. Train Model\n",
    "### Sebastian D. Goodfellow, Ph.D."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Noteboook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import 3rd party libraries\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Deep learning libraries\n",
    "import tensorflow as tf\n",
    "\n",
    "# Import local Libraries\n",
    "sys.path.insert(0, r'C:\\Users\\sebastian goodfellow\\Documents\\code\\deeP_ecg')\n",
    "from deepecg.training.utils.plotting.time_series import plot_time_series_widget\n",
    "from deepecg.training.utils.devices.device_check import print_device_counts\n",
    "# from deepecg.training.train.memory.train import train\n",
    "# from deepecg.training.model.memory.model import Model\n",
    "from deepecg.config.config import DATA_DIR\n",
    "\n",
    "# Configure Notebook\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective Function\n",
    "# https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy\n",
    "\n",
    "# Global Average Pooling\n",
    "# https://alexisbcook.github.io/2017/global-average-pooling-layers-for-object-localization/\n",
    "# https://github.com/philipperemy/tensorflow-class-activation-mapping/blob/master/class_activation_map.py\n",
    "# https://github.com/AndersonJo/global-average-pooling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load ECG Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C:\\Users\\sebastian goodfellow\\Documents\\code\\deeP_ecg\\data\\training\\memory\\training_60s.pickle\n"
     ]
    },
    {
     "ename": "EOFError",
     "evalue": "Ran out of input",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mEOFError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-9-dd744a160af8>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[1;31m# Unpickle\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'training_60s.pickle'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"rb\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0minput_file\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m     \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;31mEOFError\u001b[0m: Ran out of input"
     ]
    }
   ],
   "source": [
    "# Set path\n",
    "path = os.path.join(DATA_DIR, 'training', 'memory')\n",
    "\n",
    "# Set sample rate\n",
    "fs = 300\n",
    "\n",
    "print(os.path.join(path, 'training_60s.pickle'))\n",
    "\n",
    "# Unpickle\n",
    "with open(os.path.join(path, 'training_60s.pickle'), \"rb\") as input_file:\n",
    "    data = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train dimensions: (5774, 18000, 1)\n",
      "y_train dimensions: (5774, 1)\n",
      "x_val dimensions: (2475, 18000, 1)\n",
      "y_val dimensions: (2475, 1)\n"
     ]
    }
   ],
   "source": [
    "# Get training data\n",
    "x_train = data['data_train'].values.reshape(data['data_train'].shape[0], data['data_train'].shape[1], 1)\n",
    "y_train = data['labels_train']['label_int'].values.reshape(data['labels_train'].shape[0], 1).astype(int)\n",
    "\n",
    "# Get validation data\n",
    "x_val = data['data_val'].values.reshape(data['data_val'].shape[0], data['data_val'].shape[1], 1)\n",
    "y_val = data['labels_val']['label_int'].values.reshape(data['labels_val'].shape[0], 1).astype(int)\n",
    "\n",
    "# Print dimensions\n",
    "print('x_train dimensions: ' + str(x_train.shape))\n",
    "print('y_train dimensions: ' + str(y_train.shape))\n",
    "print('x_val dimensions: ' + str(x_val.shape))\n",
    "print('y_val dimensions: ' + str(y_val.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train dimensions: (5774, 18000, 1)\n",
      "y_train dimensions: (5774, 1)\n",
      "y_train_1hot dimensions: (5774, 3)\n",
      "x_val dimensions: (2475, 18000, 1)\n",
      "y_val dimensions: (2475, 1)\n",
      "y_val_1hot dimensions: (2475, 3)\n"
     ]
    }
   ],
   "source": [
    "# One hot encoding array dimensions\n",
    "y_train_1hot = one_hot_encoding(labels=y_train.ravel(), classes=len(np.unique(y_train.ravel())))\n",
    "y_val_1hot = one_hot_encoding(labels=y_val.ravel(), classes=len(np.unique(y_val.ravel())))\n",
    "\n",
    "# Print dimensions\n",
    "print('x_train dimensions: ' + str(x_train.shape))\n",
    "print('y_train dimensions: ' + str(y_train.shape))\n",
    "print('y_train_1hot dimensions: ' + str(y_train_1hot.shape))\n",
    "print('x_val dimensions: ' + str(x_val.shape))\n",
    "print('y_val dimensions: ' + str(y_val.shape))\n",
    "print('y_val_1hot dimensions: ' + str(y_val_1hot.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train: Classes: [0 1 2]\n",
      "Train: Count: [3553  531 1690]\n",
      "Val: Classes: [0 1 2]\n",
      "Val: Count: [1523  227  725]\n"
     ]
    }
   ],
   "source": [
    "# Label lookup\n",
    "label_lookup = {'N': 0, 'A': 1, 'O': 2, '~': 3}\n",
    "\n",
    "# Label dimensions\n",
    "print('Train: Classes: ' + str(np.unique(y_train.ravel())))\n",
    "print('Train: Count: ' + str(np.bincount(y_train.ravel())))\n",
    "print('Val: Classes: ' + str(np.unique(y_val.ravel())))\n",
    "print('Val: Count: ' + str(np.bincount(y_val.ravel())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e171fc9e94894277bcf551ee6d9c7215",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(IntSlider(value=2886, description='index', max=5773), Output()), _dom_classes=('widget-i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Label dictionary\n",
    "label_list = ['Normal Sinus Rhythm', 'Atrial Fibrillation', 'Other Rhythm']\n",
    "\n",
    "# PLot times series\n",
    "plot_time_series_widget(time_series=x_train, labels=y_train, fs=fs, label_list=label_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Device Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Workstation has 1 CPUs.\n",
      "Workstation has 2 GPUs.\n"
     ]
    }
   ],
   "source": [
    "# Get GPU count\n",
    "print_device_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU Assign\n",
      "/gpu:0\n",
      "CPU Assign\n",
      "/gpu:1\n"
     ]
    },
    {
     "ename": "InvalidArgumentError",
     "evalue": "Malformed device specification '//cpu:0'\n\t [[Node: train_op/DeepECG/logits/kernel/Adam_1 = VariableV2[_class=[\"loc:@DeepECG/logits/kernel\"], container=\"\", dtype=DT_FLOAT, shape=[64,3], shared_name=\"\", _device=\"//cpu:0\"]()]]\n\nCaused by op 'train_op/DeepECG/logits/kernel/Adam_1', defined at:\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\runpy.py\", line 193, in _run_module_as_main\n    \"__main__\", mod_spec)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\runpy.py\", line 85, in _run_code\n    exec(code, run_globals)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel_launcher.py\", line 16, in <module>\n    app.launch_new_instance()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\traitlets\\config\\application.py\", line 658, in launch_instance\n    app.start()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelapp.py\", line 486, in start\n    self.io_loop.start()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\platform\\asyncio.py\", line 127, in start\n    self.asyncio_loop.run_forever()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\asyncio\\base_events.py\", line 422, in run_forever\n    self._run_once()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\asyncio\\base_events.py\", line 1434, in _run_once\n    handle._run()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\asyncio\\events.py\", line 145, in _run\n    self._callback(*self._args)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\platform\\asyncio.py\", line 117, in _handle_events\n    handler_func(fileobj, events)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\stack_context.py\", line 276, in null_wrapper\n    return fn(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\zmq\\eventloop\\zmqstream.py\", line 450, in _handle_events\n    self._handle_recv()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\zmq\\eventloop\\zmqstream.py\", line 480, in _handle_recv\n    self._run_callback(callback, msg)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\zmq\\eventloop\\zmqstream.py\", line 432, in _run_callback\n    callback(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\stack_context.py\", line 276, in null_wrapper\n    return fn(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelbase.py\", line 283, in dispatcher\n    return self.dispatch_shell(stream, msg)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelbase.py\", line 233, in dispatch_shell\n    handler(stream, idents, msg)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelbase.py\", line 399, in execute_request\n    user_expressions, allow_stdin)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\ipkernel.py\", line 208, in do_execute\n    res = shell.run_cell(code, store_history=store_history, silent=silent)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\zmqshell.py\", line 537, in run_cell\n    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2662, in run_cell\n    raw_cell, store_history, silent, shell_futures)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2785, in _run_cell\n    interactivity=interactivity, compiler=compiler, result=result)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2903, in run_ast_nodes\n    if self.run_code(code, result):\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2963, in run_code\n    exec(code_obj, self.user_global_ns, self.user_ns)\n  File \"<ipython-input-15-a434317e092d>\", line 37, in <module>\n    max_to_keep=max_to_keep\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\model.py\", line 48, in __init__\n    self.initialize_graph()\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\model.py\", line 56, in initialize_graph\n    self.graph = Graph(network=self.network, save_path=self.save_path, max_to_keep=self.max_to_keep)\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\graph.py\", line 55, in __init__\n    self.build_graph()\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\graph.py\", line 86, in build_graph\n    global_step=self.global_step)\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\graph.py\", line 299, in _run_optimization_step\n    return optimizer.apply_gradients(grads_and_vars=gradients, global_step=global_step)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\optimizer.py\", line 600, in apply_gradients\n    self._create_slots(var_list)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\adam.py\", line 132, in _create_slots\n    self._zeros_slot(v, \"v\", self._name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\optimizer.py\", line 1150, in _zeros_slot\n    new_slot_variable = slot_creator.create_zeros_slot(var, op_name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\slot_creator.py\", line 181, in create_zeros_slot\n    colocate_with_primary=colocate_with_primary)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\slot_creator.py\", line 155, in create_slot_with_initializer\n    dtype)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\slot_creator.py\", line 65, in _create_slot_var\n    validate_shape=validate_shape)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 1317, in get_variable\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 1079, in get_variable\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 425, in get_variable\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 394, in _true_getter\n    use_resource=use_resource, constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 786, in _get_single_variable\n    use_resource=use_resource)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 2220, in variable\n    use_resource=use_resource)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 2210, in <lambda>\n    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 2193, in default_variable_creator\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variables.py\", line 235, in __init__\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variables.py\", line 349, in _init_from_args\n    name=name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\state_ops.py\", line 137, in variable_op_v2\n    shared_name=shared_name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\gen_state_ops.py\", line 1254, in variable_v2\n    shared_name=shared_name, name=name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\framework\\op_def_library.py\", line 787, in _apply_op_helper\n    op_def=op_def)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\", line 3392, in create_op\n    op_def=op_def)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\", line 1718, in __init__\n    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access\n\nInvalidArgumentError (see above for traceback): Malformed device specification '//cpu:0'\n\t [[Node: train_op/DeepECG/logits/kernel/Adam_1 = VariableV2[_class=[\"loc:@DeepECG/logits/kernel\"], container=\"\", dtype=DT_FLOAT, shape=[64,3], shared_name=\"\", _device=\"//cpu:0\"]()]]\n",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mInvalidArgumentError\u001b[0m                      Traceback (most recent call last)",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m   1321\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1322\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1323\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m   1304\u001b[0m       \u001b[1;31m# Ensure any changes to the graph are reflected in the runtime.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1305\u001b[1;33m       \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1306\u001b[0m       return self._call_tf_sessionrun(\n",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_extend_graph\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m   1339\u001b[0m       \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[1;33m:\u001b[0m  \u001b[1;31m# pylint: disable=protected-access\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1340\u001b[1;33m         \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mExtendSession\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1341\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mInvalidArgumentError\u001b[0m: Malformed device specification '//cpu:0'\n\t [[Node: train_op/DeepECG/logits/kernel/Adam_1 = VariableV2[_class=[\"loc:@DeepECG/logits/kernel\"], container=\"\", dtype=DT_FLOAT, shape=[64,3], shared_name=\"\", _device=\"//cpu:0\"]()]]",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[1;31mInvalidArgumentError\u001b[0m                      Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-15-a434317e092d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     35\u001b[0m     \u001b[0mnetwork_parameters\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mnetwork_parameters\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     36\u001b[0m     \u001b[0msave_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msave_path\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 37\u001b[1;33m     \u001b[0mmax_to_keep\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mmax_to_keep\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     38\u001b[0m )\n",
      "\u001b[1;32m~\\Documents\\code\\dee_ecg\\model\\model.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, model_name, network_name, network_parameters, save_path, max_to_keep)\u001b[0m\n\u001b[0;32m     46\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     47\u001b[0m         \u001b[1;31m# Initialize graph\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 48\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minitialize_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     49\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0minitialize_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Documents\\code\\dee_ecg\\model\\model.py\u001b[0m in \u001b[0;36minitialize_graph\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     60\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     61\u001b[0m         \u001b[1;31m# Initialize global variables\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 62\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msess\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgraph\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_global\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     63\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     64\u001b[0m         \u001b[1;31m# Save network object\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m    898\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    899\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 900\u001b[1;33m                          run_metadata_ptr)\n\u001b[0m\u001b[0;32m    901\u001b[0m       \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    902\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1133\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1134\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1135\u001b[1;33m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m   1136\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1137\u001b[0m       \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1314\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1315\u001b[0m       return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[1;32m-> 1316\u001b[1;33m                            run_metadata)\n\u001b[0m\u001b[0;32m   1317\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1318\u001b[0m       \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m   1333\u001b[0m         \u001b[1;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1334\u001b[0m           \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1335\u001b[1;33m       \u001b[1;32mraise\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode_def\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mop\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1336\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1337\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_extend_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mInvalidArgumentError\u001b[0m: Malformed device specification '//cpu:0'\n\t [[Node: train_op/DeepECG/logits/kernel/Adam_1 = VariableV2[_class=[\"loc:@DeepECG/logits/kernel\"], container=\"\", dtype=DT_FLOAT, shape=[64,3], shared_name=\"\", _device=\"//cpu:0\"]()]]\n\nCaused by op 'train_op/DeepECG/logits/kernel/Adam_1', defined at:\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\runpy.py\", line 193, in _run_module_as_main\n    \"__main__\", mod_spec)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\runpy.py\", line 85, in _run_code\n    exec(code, run_globals)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel_launcher.py\", line 16, in <module>\n    app.launch_new_instance()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\traitlets\\config\\application.py\", line 658, in launch_instance\n    app.start()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelapp.py\", line 486, in start\n    self.io_loop.start()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\platform\\asyncio.py\", line 127, in start\n    self.asyncio_loop.run_forever()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\asyncio\\base_events.py\", line 422, in run_forever\n    self._run_once()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\asyncio\\base_events.py\", line 1434, in _run_once\n    handle._run()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\asyncio\\events.py\", line 145, in _run\n    self._callback(*self._args)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\platform\\asyncio.py\", line 117, in _handle_events\n    handler_func(fileobj, events)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\stack_context.py\", line 276, in null_wrapper\n    return fn(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\zmq\\eventloop\\zmqstream.py\", line 450, in _handle_events\n    self._handle_recv()\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\zmq\\eventloop\\zmqstream.py\", line 480, in _handle_recv\n    self._run_callback(callback, msg)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\zmq\\eventloop\\zmqstream.py\", line 432, in _run_callback\n    callback(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tornado\\stack_context.py\", line 276, in null_wrapper\n    return fn(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelbase.py\", line 283, in dispatcher\n    return self.dispatch_shell(stream, msg)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelbase.py\", line 233, in dispatch_shell\n    handler(stream, idents, msg)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\kernelbase.py\", line 399, in execute_request\n    user_expressions, allow_stdin)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\ipkernel.py\", line 208, in do_execute\n    res = shell.run_cell(code, store_history=store_history, silent=silent)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\ipykernel\\zmqshell.py\", line 537, in run_cell\n    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2662, in run_cell\n    raw_cell, store_history, silent, shell_futures)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2785, in _run_cell\n    interactivity=interactivity, compiler=compiler, result=result)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2903, in run_ast_nodes\n    if self.run_code(code, result):\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2963, in run_code\n    exec(code_obj, self.user_global_ns, self.user_ns)\n  File \"<ipython-input-15-a434317e092d>\", line 37, in <module>\n    max_to_keep=max_to_keep\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\model.py\", line 48, in __init__\n    self.initialize_graph()\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\model.py\", line 56, in initialize_graph\n    self.graph = Graph(network=self.network, save_path=self.save_path, max_to_keep=self.max_to_keep)\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\graph.py\", line 55, in __init__\n    self.build_graph()\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\graph.py\", line 86, in build_graph\n    global_step=self.global_step)\n  File \"C:\\Users\\sebastian goodfellow\\Documents\\code\\dee_ecg\\model\\graph.py\", line 299, in _run_optimization_step\n    return optimizer.apply_gradients(grads_and_vars=gradients, global_step=global_step)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\optimizer.py\", line 600, in apply_gradients\n    self._create_slots(var_list)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\adam.py\", line 132, in _create_slots\n    self._zeros_slot(v, \"v\", self._name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\optimizer.py\", line 1150, in _zeros_slot\n    new_slot_variable = slot_creator.create_zeros_slot(var, op_name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\slot_creator.py\", line 181, in create_zeros_slot\n    colocate_with_primary=colocate_with_primary)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\slot_creator.py\", line 155, in create_slot_with_initializer\n    dtype)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\training\\slot_creator.py\", line 65, in _create_slot_var\n    validate_shape=validate_shape)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 1317, in get_variable\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 1079, in get_variable\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 425, in get_variable\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 394, in _true_getter\n    use_resource=use_resource, constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 786, in _get_single_variable\n    use_resource=use_resource)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 2220, in variable\n    use_resource=use_resource)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 2210, in <lambda>\n    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variable_scope.py\", line 2193, in default_variable_creator\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variables.py\", line 235, in __init__\n    constraint=constraint)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\variables.py\", line 349, in _init_from_args\n    name=name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\state_ops.py\", line 137, in variable_op_v2\n    shared_name=shared_name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\ops\\gen_state_ops.py\", line 1254, in variable_v2\n    shared_name=shared_name, name=name)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\framework\\op_def_library.py\", line 787, in _apply_op_helper\n    op_def=op_def)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\", line 3392, in create_op\n    op_def=op_def)\n  File \"C:\\Continuum\\anaconda3\\envs\\deep_ecg\\lib\\site-packages\\tensorflow\\python\\framework\\ops.py\", line 1718, in __init__\n    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access\n\nInvalidArgumentError (see above for traceback): Malformed device specification '//cpu:0'\n\t [[Node: train_op/DeepECG/logits/kernel/Adam_1 = VariableV2[_class=[\"loc:@DeepECG/logits/kernel\"], container=\"\", dtype=DT_FLOAT, shape=[64,3], shared_name=\"\", _device=\"//cpu:0\"]()]]\n"
     ]
    }
   ],
   "source": [
    "# Set save path for graphs, summaries, and checkpoints\n",
    "save_path = r'C:\\Users\\sebastian goodfellow\\Desktop\\tensorboard\\deep_ecg'\n",
    "\n",
    "# Set model name\n",
    "model_name = 'test_1'\n",
    "\n",
    "# Maximum number of checkpoints to keep\n",
    "max_to_keep = 20\n",
    "\n",
    "# Set randome states\n",
    "seed = 0                                    \n",
    "tf.set_random_seed(seed)                      \n",
    "\n",
    "# Get training dataset dimensions\n",
    "(m, length, channels) = x_train.shape  \n",
    "\n",
    "# Get number of label classes\n",
    "classes = y_train_1hot.shape[1]                  \n",
    "\n",
    "# Choose network\n",
    "network_name = 'DeepECG'\n",
    "\n",
    "# Set network inputs\n",
    "network_parameters = dict(\n",
    "    length=length,\n",
    "    channels=channels, \n",
    "    classes=classes, \n",
    "    seed=seed,  \n",
    ")\n",
    "\n",
    "# Create model\n",
    "model = Model(\n",
    "    model_name=model_name, \n",
    "    network_name=network_name, \n",
    "    network_parameters=network_parameters, \n",
    "    save_path=save_path,\n",
    "    max_to_keep=max_to_keep\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# 7. Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set hyper-parameters\n",
    "epochs = 100\n",
    "minibatch_size =  10\n",
    "learning_rate = 0.001            \n",
    "\n",
    "# Train model\n",
    "train(\n",
    "    model=model, \n",
    "    x_train=x_train, y_train=y_train_1hot, \n",
    "    x_val=x_val, y_val=y_val_1hot,\n",
    "    learning_rate=learning_rate,\n",
    "    epochs=epochs, mini_batch_size=minibatch_size, \n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
